import math
import torch
import torch.nn as nn
import sys
from torch import nn, optim
import time
import numpy as np
import random
import os
import scipy.io as sio
from torch.nn import functional as F
from scipy.io import savemat
import torchvision

# One fixed seed is pushed into every random number generator this script
# touches (NumPy, Python's `random`, and PyTorch CPU/CUDA).
seed = 12
for _seed_rng in (
    np.random.seed,
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(seed)

# cuDNN: disable auto-tuning and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Experiment configuration: number of training steps, learning rate,
# and the codebook shape (number of codes x dimension of each code).
max_steps, lr_init = 2000, 1.0
codebook_size, codebook_dim = 16384, 8

class Dictionary(nn.Module):
    """Vector-quantization codebook updated by exponential moving averages.

    The codebook lives in ``self.embedding`` (gradients disabled). Each call
    to :meth:`quantize` assigns the inputs to their nearest code and refreshes
    the codebook from EMA statistics; :meth:`calc_metrics` reports
    quantization error, codebook usage and a Gaussian Wasserstein distance
    between the inputs and the codebook without modifying any state.

    Args:
        codebook_size: Number of code vectors.
        codebook_dim: Dimension of each code vector.
    """

    def __init__(self, codebook_size, codebook_dim):
        super().__init__()
        # The explicit randn draw happens before nn.Embedding's own init so
        # that the random stream is consumed in the same order as before.
        initial = torch.randn(codebook_size, codebook_dim)
        self.embedding = nn.Embedding(codebook_size, codebook_dim)
        self.embedding.weight.data.copy_(initial)
        self.embedding.weight.requires_grad = False
        self.codebook_size = codebook_size
        self.eps = 1e-5    # smoothing constant used in weight_update
        self.decay = 0.9   # EMA decay for cluster_size and embed_avg
        # EMA of the per-code assignment counts (starts at zero).
        self.cluster_size = torch.nn.Parameter(torch.zeros(self.codebook_size), requires_grad=False)
        # EMA of the per-code sum of assigned inputs (starts at the codebook).
        self.embed_avg = torch.nn.Parameter(self.embedding.weight.clone(), requires_grad=False)

    def calc_wasserstein_distance(self, z):
        """Return the 2-Wasserstein distance between Gaussian fits of z and the codebook.

        Both ``z`` (rows are samples) and the codebook are summarised by their
        mean and (biased, divide-by-count) covariance. The distance is
        ``sqrt(|m_z - m_c|^2 + Tr(S_z + S_c - 2 (S_z^{1/2} S_c S_z^{1/2})^{1/2}))``.

        Args:
            z: 2-D tensor of shape (num_samples, codebook_dim).

        Returns:
            A 0-dim tensor holding the distance.
        """
        codebook = self.embedding.weight
        num_samples = z.size(0)

        # Mean and covariance of the inputs.
        z_mean = z.mean(0)
        z_centered = z - z_mean
        z_covariance = torch.mm(z_centered.t(), z_centered) / num_samples

        # Mean and covariance of the codebook vectors.
        c_mean = codebook.mean(0)
        c_centered = codebook - c_mean
        c_covariance = torch.mm(c_centered.t(), c_centered) / self.codebook_size

        # Part 1: squared distance between the means.
        mean_diff = z_mean - c_mean
        part_mean = torch.sum(mean_diff * mean_diff)

        # Part 2: covariance term. The plain product S_z @ S_c is not
        # symmetric in general, and eigh only reads one triangle of its
        # input, so it must not be decomposed directly. Instead use
        # S_z^{1/2} S_c S_z^{1/2}, which is symmetric PSD and has the same
        # eigenvalues as the product.
        z_eigvals, z_eigvecs = torch.linalg.eigh(z_covariance)
        z_sqrt = (z_eigvecs * torch.sqrt(F.relu(z_eigvals) + 1e-8)) @ z_eigvecs.T
        inner = z_sqrt @ c_covariance @ z_sqrt
        inner = 0.5 * (inner + inner.T)  # remove round-off asymmetry
        inner_eigvals = torch.linalg.eigvalsh(inner)
        # Only the trace of the matrix square root is needed, i.e. the sum
        # of the square roots of the (clamped) eigenvalues.
        trace_sqrt = torch.sqrt(F.relu(inner_eigvals) + 1e-8).sum()

        part_covariance = F.relu(
            torch.trace(z_covariance) + torch.trace(c_covariance) - 2.0 * trace_sqrt
        )
        wasserstein_loss = torch.sqrt(part_mean + part_covariance + 1e-8)
        return wasserstein_loss

    def cluster_size_ema_update(self, new_cluster_size):
        """Blend ``new_cluster_size`` into ``self.cluster_size`` in place with weight ``1 - decay``."""
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        """Blend ``new_embed_avg`` into ``self.embed_avg`` in place with weight ``1 - decay``."""
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        """Overwrite the codebook with ``embed_avg`` divided by smoothed cluster sizes.

        Args:
            num_tokens: Multiplier of ``eps`` in the smoothing denominator
                (``quantize`` passes ``self.codebook_size``).
        """
        n = self.cluster_size.sum()
        # Additive smoothing keeps the divisor away from zero for codes whose
        # EMA count is zero.
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # Normalize embedding average with smoothed cluster size.
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.embedding.weight.data.copy_(embed_normalized)

    def _nearest_tokens(self, z):
        """Return, for each row of ``z``, the index of the closest code (squared L2).

        NOTE: materialises a (num_samples, codebook_size) distance matrix.
        """
        z = z.detach()
        codebook = self.embedding.weight.data
        # |z|^2 + |c|^2 - 2 z.c, with the cross term accumulated in place.
        distance = torch.sum(z.square(), dim=1, keepdim=True) + torch.sum(codebook.square(), dim=1, keepdim=False)
        distance.addmm_(z, codebook.T, alpha=-2, beta=1)
        return torch.argmin(distance, dim=1)

    def quantize(self, z):
        """Assign ``z`` to its nearest codes and run one EMA codebook update.

        Args:
            z: 2-D tensor of shape (num_samples, codebook_dim).

        Returns:
            0-dim tensor: mean squared L2 error between each input and the
            code it was assigned to, measured before the codebook is updated.
        """
        # Gradients are never needed here; detach once so the EMA statistics
        # are built from plain values.
        z = z.detach()
        token = self._nearest_tokens(z)
        embed = self.embedding(token)
        quant_error = (embed - z).square().sum(1).mean()
        onehot_probs = F.one_hot(token, self.codebook_size).type(z.dtype)

        # EMA cluster size: number of inputs assigned to each code.
        self.cluster_size_ema_update(onehot_probs.sum(0))

        # EMA embedding average: sum of the inputs assigned to each code.
        embed_sum = onehot_probs.transpose(0, 1) @ z
        self.embed_avg_ema_update(embed_sum)

        # Normalize embed_avg and update weight.
        self.weight_update(self.codebook_size)
        return quant_error

    def calc_metrics(self, z):
        """Evaluate the current codebook on ``z`` without changing any state.

        Args:
            z: 2-D tensor of shape (num_samples, codebook_dim).

        Returns:
            Tuple ``(quant_error, codebook_utilization, codebook_perplexity,
            wasserstein_distance)``: mean squared L2 quantization error
            (0-dim tensor), fraction of codes selected at least once (Python
            float), exponential of the entropy of the code-usage histogram
            (0-dim tensor), and the result of
            :meth:`calc_wasserstein_distance` (0-dim tensor).
        """
        token = self._nearest_tokens(z)
        embed = self.embedding(token)

        quant_error = (embed - z.detach()).square().sum(1).mean()
        codebook_histogram = token.bincount(minlength=self.codebook_size).float()
        codebook_usage_counts = (codebook_histogram > 0).float().sum()
        codebook_utilization = codebook_usage_counts.item() / self.codebook_size

        avg_probs = codebook_histogram / codebook_histogram.sum(0)
        # The 1e-10 keeps log() finite for codes that were never selected.
        codebook_perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        wasserstein_distance = self.calc_wasserstein_distance(z)

        return quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance


# Build the codebook and move it, together with its EMA statistics, to the
# GPU. A CUDA device is required.
Dict = Dictionary(codebook_size, codebook_dim).cuda()
#optimizer = torch.optim.SGD(Dict.embedding.parameters(), lr=lr_init, momentum=0.9)

# Shared template so the step-0 report and the periodic reports line up.
eval_report = 'eval step:{}/{}, quant_error:{:.4f}, codebook_utilization:{:.4f}, codebook_perplexity:{:.4f}, wasserstein_distance:{:.4f}'

##### zero-steps: evaluate the freshly initialised codebook before any update.
# Inputs are standard-normal samples shifted by +4.0 in every dimension.
z = torch.randn(1000000, codebook_dim).cuda() + 4.0
quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance = Dict.calc_metrics(z)
# Report the baseline instead of discarding it.
print(eval_report.format(0, max_steps, quant_error.item(), codebook_utilization, codebook_perplexity.item(), wasserstein_distance.item()))

for step in range(1, max_steps+1):
    # One EMA codebook update on a fresh batch of 50000 samples.
    z = torch.randn(50000, codebook_dim).cuda() + 4.0
    quant_error = Dict.quantize(z)

    # Training error is logged on the first step and every 10th step.
    if step == 1 or step%10 == 0:
        print('train step:{}/{}, quant_error:{:.4f}'.format(step, max_steps, quant_error.item()))
    # Full evaluation on 1000000 fresh samples on the first step and every 100th step.
    if step == 1 or step%100 == 0:
        z = torch.randn(1000000, codebook_dim).cuda() + 4.0
        quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance = Dict.calc_metrics(z)

        print(eval_report.format(step, max_steps, quant_error.item(), codebook_utilization, codebook_perplexity.item(), wasserstein_distance.item()))

